07-21-2021
torch.fft as the fastest option

Multiple ways to compute the FFT in Python:
- numpy.fft
- scipy.fftpack -- considered legacy; SciPy recommends using scipy.fft instead
- scipy.fft
- pyfftw
- cupy.fft -- CuPy's FFT module, which evaluates FFTs on NVIDIA GPUs (via cuFFT)
- torch.fft

scipy.fft is slightly faster than numpy.fft, so it is used in this notebook.
import sys
# Record which interpreter/environment produced the results below.
print(f'Environment: {sys.exec_prefix}')
print(f'Executable: {sys.executable}')
print(f'Version: {sys.version}')
Environment: /anaconda3/envs/fourier Executable: /anaconda3/envs/fourier/bin/python Version: 3.9.4 (default, Apr 9 2021, 09:32:38) [Clang 10.0.0 ]
import numpy as np
import scipy.misc
from scipy import ndimage, signal
from scipy.fft import fft, ifft, fftn, ifftn, fftshift, ifftshift, fftfreq, rfft, irfft, rfftfreq
import skimage
from skimage.util import img_as_float
from skimage.data import camera, gravel, checkerboard
from skimage.filters import difference_of_gaussians, window
import matplotlib
from matplotlib import cm
import matplotlib.pyplot as plt
#plt.style.use('seaborn-poster')
%matplotlib inline
# Record library versions so the notebook's results can be reproduced.
for _lib_name, _lib_module in (('numpy', np), ('scipy', scipy),
                               ('skimage', skimage), ('matplotlib', matplotlib)):
    print(f'{_lib_name}: {_lib_module.__version__}')
numpy: 1.20.1 scipy: 1.6.2 skimage: 0.18.1 matplotlib: 3.3.4
General comments:
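The FFT is fastest when npnts is a power of 2 (more generally, when npnts factors into small primes such as 2, 3 and 5); it is slower for other lengths and slowest when npnts is a large prime.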
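The amplitude spectrum is obtained by plotting 2*np.abs(X)/npnts, rather than simply np.abs(X), versus frequency. Dividing by npnts undoes the scaling of the unnormalized forward FFT. The factor of 2 accounts for each real sinusoid's amplitude being split equally between its positive- and negative-frequency bins. With this normalization, the peaks show the actual amplitudes of the sine waves used to generate the signal (2, 4 and 6 at 10, 70 and 170 Hz, respectively). Note that this is an amplitude spectrum, not a power spectrum (which would be the squared magnitude). Note also that the 0 Hz (DC) bin is not split, so its value is overstated by a factor of 2.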
The frequency axis can be defined in different ways:
Method 1 (Frequency mirroring above Nyquist):\
Manually define the frequency axis in this range, [0, Sampling Rate]. The FFT Amplitude of the signal now has frequency peaks that are mirrored in the [Nyquist, Sampling Rate] too. Ignore the mirrored peaks and only consider this range, [0 to Nyquist frequency]. Recall that Nyquist frequency is half of the Sampling Rate.
Method 2 (Frequency mirroring about zero):\
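Define the frequency axis using fftfreq. It returns the frequency of each fft() output bin in standard order: 0, then the positive frequencies, then the negative frequencies. It does not reorder or centre anything; fftshift does that. The frequency axis now spans [-Nyquist, Nyquist), and the peaks are mirrored onto the negative axis. Again, ignore the negative frequencies and only consider the range [0, Nyquist frequency].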
Method 3 (No mirroring-- only positive frequencies):\
The spectrum that fft() outputs is symmetric about 0 Hz because the input consists of real numbers rather than complex numbers. For real input, X[-k] is the complex conjugate of X[k], so the negative-frequency half carries no new information. rfft() takes advantage of this symmetry and computes only the non-negative frequencies, which makes the transform faster (roughly half the work and half the output).
def fft_powerspectrum_plot(x, t, dt, npnts, freq_show, method=2, denoise=False, cutoff=0):
    """Plot a 1D signal, its FFT amplitude spectrum, and the signal recovered by the inverse FFT.

    Three stacked panels are drawn: the original signal, the amplitude spectrum
    ``2*|X|/npnts`` (despite the function name this is an amplitude, not a power,
    spectrum), and the signal reconstructed from the (optionally thresholded)
    Fourier coefficients. Afterwards the number of time points, the number of
    frequency points and the frequency range are printed.

    Parameters
    ----------
    x : array_like
        1D real signal.
    t : array_like
        Time vector (s), same length as ``x``.
    dt : float
        Sampling interval (s).
    npnts : int
        Number of time points (``len(x)``).
    freq_show : float
        Upper limit (Hz) of the displayed frequency axis.
    method : {1, 2, 3}, optional
        1: ``fft`` with a manual [0, sampling rate) axis (peaks mirrored above Nyquist);
        2: ``fft`` with ``fftfreq`` (peaks mirrored about 0 Hz; default);
        3: ``rfft`` with ``rfftfreq`` (non-negative frequencies only).
    denoise : bool, optional
        If True, zero every Fourier coefficient whose amplitude is not above ``cutoff``.
    cutoff : float, optional
        Amplitude threshold used when ``denoise`` is True.

    Raises
    ------
    ValueError
        If ``method`` is not 1, 2 or 3.

    Example
    -------
    fft_powerspectrum_plot(x, t, dt, npnts, 200, method=1, denoise=True, cutoff=3)
    """
    if method == 1:
        X = fft(x)
        # Bin k corresponds to k / T, where T = npnts*dt is the total duration.
        freq = np.arange(npnts) / (npnts * dt)
        title = 'Frequency mirroring above Nyquist'
    elif method == 2:
        X = fft(x)
        freq = fftfreq(npnts, dt)
        title = 'Frequency mirroring about zero'
    elif method == 3:
        X = rfft(x)
        freq = rfftfreq(npnts, dt)
        title = 'No mirroring: only positive frequencies'
    else:
        raise ValueError(f'method must be 1, 2 or 3, got {method!r}')

    # Divide by npnts to undo the unnormalized forward FFT; multiply by 2 because a real
    # sinusoid's amplitude is split between its +f and -f bins.
    powerspect = 2 * np.abs(X) / npnts
    if denoise:
        keep = powerspect > cutoff  # one mask, applied to the spectrum and the coefficients
        powerspect = powerspect * keep
        X = X * keep
        # Optional: to also zero the large 0 Hz peak that appears when the noise is
        # np.random.rand (uniform, non-zero mean) instead of randn, additionally require
        # powerspect < 10 in the mask.

    # Inverse transform. The input is real, so ifft's tiny imaginary round-off is dropped
    # (plotting the complex array triggered a ComplexWarning); irfft needs n=npnts to
    # return the original length when npnts is odd.
    if method == 3:
        recovered = irfft(X, n=npnts)
    else:
        recovered = ifft(X).real

    plt.figure(figsize=(18, 12))

    plt.subplot(311)
    plt.plot(t, x, 'k', label='original')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.legend()

    plt.subplot(312)
    plt.title(title)
    plt.stem(freq, powerspect, 'c', markerfmt=" ", basefmt="-b")
    plt.xlabel('Freq (Hz)')
    plt.ylabel('FFT Amplitude')
    # Method 2 includes negative frequencies, so show a window symmetric about 0 Hz.
    if method == 2:
        plt.xlim(-freq_show, freq_show)
    else:
        plt.xlim(0, freq_show)

    plt.subplot(313)
    plt.plot(t, recovered, 'k--', label='recovered')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.legend()

    plt.tight_layout()
    plt.show()

    print('Number of time points:', npnts, 'points')
    print('Number of points in frequency range:', len(freq), 'points')
    print('Frequency range:', min(freq), max(freq), 'Hz')
# Sampling parameters
srate = 1000                       # sampling rate (Hz)
dt = 1/srate                       # sampling interval (s)
duration = 1                       # signal length (s)
t = np.arange(0, duration, dt)     # time vector (s)
npnts = len(t)

# Test signal: sinusoids at 10, 70 and 170 Hz (amplitudes 2, 4, 6) plus Gaussian noise (std 6)
freq1, freq2, freq3 = 10.0, 70.0, 170.0
x = (2*np.sin(2*np.pi*freq1*t)
     + 4*np.cos(2*np.pi*freq2*t)
     + 6*np.cos(2*np.pi*freq3*t+np.pi/4))
noise = 6.0*np.random.randn(npnts)
x += noise

fft_powerspectrum_plot(x, t, dt, npnts, freq_show=200, method=1)
/anaconda3/envs/fourier/lib/python3.9/site-packages/numpy/core/_asarray.py:102: ComplexWarning: Casting complex values to real discards the imaginary part return array(a, dtype, copy=False, order=order)
Number of time points: 1000 points Number of points in frequency range: 1000 points Frequency range: 0.0 999.0 Hz
fft_powerspectrum_plot(x, t, dt, npnts, freq_show=200, method=1, denoise=True, cutoff=1)
/anaconda3/envs/fourier/lib/python3.9/site-packages/numpy/core/_asarray.py:102: ComplexWarning: Casting complex values to real discards the imaginary part return array(a, dtype, copy=False, order=order)
Number of time points: 1000 points Number of points in frequency range: 1000 points Frequency range: 0.0 999.0 Hz
fft_powerspectrum_plot(x, t, dt, npnts, freq_show=200, method=2, denoise=True, cutoff=1)
/anaconda3/envs/fourier/lib/python3.9/site-packages/numpy/core/_asarray.py:102: ComplexWarning: Casting complex values to real discards the imaginary part return array(a, dtype, copy=False, order=order)
Number of time points: 1000 points Number of points in frequency range: 1000 points Frequency range: -500.0 499.0 Hz
fft_powerspectrum_plot(x, t, dt, npnts, freq_show=200, method=3, denoise=True, cutoff=1)
Number of time points: 1000 points Number of points in frequency range: 501 points Frequency range: 0.0 500.0 Hz
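Let's change the frequencies of the signal to non-integer values. When we do so, the FFT amplitude spectrum shows spurious peaks around each true frequency. This is spectral leakage. With a 1 s signal the frequency resolution is 1/duration = 1 Hz, so frequencies such as 10.1, 70.5 and 170.7 Hz fall between FFT bins, and their energy spreads into neighbouring bins.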
# Sampling parameters (same 1 s record as before)
srate = 1000                       # sampling rate (Hz)
dt = 1/srate                       # sampling interval (s)
duration = 1                       # signal length (s)
t = np.arange(0, duration, dt)     # time vector (s)
npnts = len(t)

# Test signal with frequencies that do not fall on the 1 Hz FFT bins
freq1, freq2, freq3 = 10.1, 70.5, 170.7
x = (2*np.sin(2*np.pi*freq1*t)
     + 4*np.cos(2*np.pi*freq2*t)
     + 6* np.cos(2*np.pi*freq3*t+np.pi/4))
noise = 6.0*np.random.randn(npnts)
x += noise

fft_powerspectrum_plot(x, t, dt, npnts, freq_show=200)
/anaconda3/envs/fourier/lib/python3.9/site-packages/numpy/core/_asarray.py:102: ComplexWarning: Casting complex values to real discards the imaginary part return array(a, dtype, copy=False, order=order)
Number of time points: 1000 points Number of points in frequency range: 1000 points Frequency range: -500.0 499.0 Hz
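One way to mitigate these additional frequency components is to use a longer signal, as shown below. With a 10 s signal the resolution becomes 0.1 Hz, so 10.1, 70.5 and 170.7 Hz fall exactly on FFT bins. Applying a window function (e.g. Hann) before the FFT is another common way to reduce leakage.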
# Sampling parameters: a 10 s record gives a 0.1 Hz frequency resolution
srate = 1000                       # sampling rate (Hz)
dt = 1/srate                       # sampling interval (s)
duration = 10                      # signal length (s)
t = np.arange(0, duration, dt)     # time vector (s)
npnts = len(t)

# Same non-integer frequencies as before
freq1, freq2, freq3 = 10.1, 70.5, 170.7
x = (2*np.sin(2*np.pi*freq1*t)
     + 4*np.cos(2*np.pi*freq2*t)
     + 6*np.cos(2*np.pi*freq3*t+np.pi/4))
noise = 6.0*np.random.randn(npnts)
x += noise

fft_powerspectrum_plot(x, t, dt, npnts, freq_show=200)
/anaconda3/envs/fourier/lib/python3.9/site-packages/numpy/core/_asarray.py:102: ComplexWarning: Casting complex values to real discards the imaginary part return array(a, dtype, copy=False, order=order)
Number of time points: 10000 points Number of points in frequency range: 10000 points Frequency range: -500.0 499.90000000000003 Hz
Adapted from the following numpy example
# Grayscale raccoon-face test image from SciPy, cropped to a 512x512 region.
# scipy.misc.face was deprecated in SciPy 1.10 and later removed; its replacement,
# scipy.datasets.face, downloads the data on first use (requires pooch).
try:
    from scipy.misc import face as _load_face
except ImportError:
    from scipy.datasets import face as _load_face
tmp = _load_face(gray=True)
image = tmp[128:640, 256:768]
image.shape
(512, 512)
# Left panel: the cropped image
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.title('Original image')
plt.axis('off')
plt.imshow(image, cmap='gray');

# Right panel: 2D FFT with the zero-frequency term moved to the centre for display.
# Log scale because the magnitudes span many orders of magnitude.
ftimage = fftshift(fftn(image))
plt.subplot(122)
plt.imshow(np.log(np.abs(ftimage)), cmap='gray')
plt.title('FT(image)')
plt.axis('off')
plt.show()
# Build and apply a Gaussian low-pass mask in the centred Fourier domain.
# The mask lives in Fourier space and has the same shape as FT(image). Its peak must sit
# on the zero-frequency bin, which fftshift places at index (ny//2, nx//2) for any size.
# NOTE: exp(-(d/s)**2) has a standard deviation of s/sqrt(2), so sigmax/sigmay are width
# parameters (in frequency bins) rather than true standard deviations.
sigmax, sigmay = 7, 13
ny, nx = image.shape
size = image.shape[0]
cy, cx = ny // 2, nx // 2
# Integer pixel grid. np.linspace(0, size, size) has a spacing of size/(size-1), which
# centres the mask half a bin away from DC and breaks its symmetry.
x = np.arange(nx)
y = np.arange(ny)
X, Y = np.meshgrid(x, y)
gmask = np.exp(-(((X - cx)/sigmax)**2 + ((Y - cy)/sigmay)**2))

plt.figure(figsize=(12, 6))
plt.subplot(131)
plt.imshow(gmask, cmap='gray')
plt.title('Gaussian mask')

ftimagep = ftimage * gmask
plt.subplot(132)
# log1p avoids log(0) = -inf where the mask underflows to zero far from the centre
plt.imshow(np.log1p(np.abs(ftimagep)), cmap='gray')
plt.title('FT(image)*mask')
plt.axis('off')

# Undo the centring with ifftshift before the inverse FFT. The mask is symmetric about
# DC, so the result is real up to round-off, and its real part is the blurred image.
imagep = ifftn(ifftshift(ftimagep)).real
plt.subplot(133)
plt.imshow(imagep, cmap='gray')
plt.title('Blurred image')
plt.axis('off')
plt.show()
<ipython-input-12-7a6c3643cdc0>:35: RuntimeWarning: divide by zero encountered in log plt.imshow(np.log(np.abs(ftimagep)),cmap='gray')
fftconvolve
#From https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.fftconvolve.html
# 1D Gaussian windows (101 samples each) that form a separable 2D Gaussian blur kernel.
window_x = signal.windows.gaussian(101, std=7)
window_y = signal.windows.gaussian(101, std=13)
plt.figure()
# Raw strings: "\s" is an invalid escape sequence in a normal string literal.
plt.plot(window_x, label=r"$\sigma_x$=7")
plt.plot(window_y, label=r"$\sigma_y$=13")
plt.legend(loc="upper right")
plt.title("Gaussian window")
plt.ylabel("Amplitude")
plt.xlabel("Sample")

# np.outer(a, b)[i, j] = a[i] * b[j]: rows (vertical, y) follow the first factor and
# columns (horizontal, x) the second, so window_y goes first to match the labels.
kernel = np.outer(window_y, window_x)
# The kernel is defined in image (spatial) coordinates and is only 101x101 pixels,
# much smaller than the image.
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.imshow(kernel, cmap='gray')
plt.title('Gaussian Kernel')

# mode='same' keeps the output the same size as the image
blurred = signal.fftconvolve(image, kernel, mode='same')
plt.subplot(122)
plt.imshow(blurred, cmap='gray')
plt.title('Blurred image')
plt.axis('off')
plt.show()
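Notice the dark borders around the blurred image. mode='same' returns the central part of a full linear convolution, which implicitly zero-pads the image beyond its boundaries.
scipy.ndimage.gaussian_filter avoids this artifact: by default it extends the image by reflecting it at the edges (mode='reflect') instead of padding with zeros.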
# Image denoising: add uniform noise to the image, then compare a Gaussian filter
# (linear smoothing) with a median filter (rank-based).
noisy = image + 0.4*image.std()*np.random.random(image.shape)
gauss_denoised = ndimage.gaussian_filter(noisy, sigma=11)
med_denoised = ndimage.median_filter(noisy, size=9)

plt.figure(figsize=(12, 4))
_panels = [('noisy', noisy),
           ('Gaussian filter', gauss_denoised),
           ('Median filter', med_denoised)]
for _pos, (_title, _data) in enumerate(_panels, start=131):
    plt.subplot(_pos)
    # Same grey-level range in every panel so they can be compared directly
    plt.imshow(_data, cmap=plt.cm.gray, vmin=40, vmax=220)
    plt.axis('off')
    plt.title(_title, fontsize=20)
plt.subplots_adjust(wspace=0.02, hspace=0.02, top=0.9, bottom=0, left=0, right=1)
plt.show()
Note that the Gaussian filter smooths out both the noise and the edges. The median filter replaces each pixel with the median of its neighbourhood (not an average), which suppresses the noise while preserving edges better.
Adapted from the following SciKit-Image documentation
Band-pass filters attenuate signal frequencies outside of a range (band) of interest. In image analysis, they can be used to denoise images while at the same time reducing low-frequency artifacts such as uneven illumination. Band-pass filters can also be used to find image features such as blobs and edges.
One method for applying band-pass filters to images is to subtract an image blurred with a Gaussian kernel from a less-blurred image. This example shows two applications of the Difference of Gaussians approach for band-pass filtering.
More on DoG here
# Difference-of-Gaussians band-pass on the gravel image (sigmas 1 and 12)
image = gravel()
wimage = image * window('hann', image.shape)  # window image to improve FFT (tapers the borders)
filtered_image = difference_of_gaussians(image, 1, 12)
filtered_wimage = filtered_image * window('hann', image.shape)
# Centred magnitude spectra of the windowed original and filtered images
im_f_mag = fftshift(np.abs(fftn(wimage)))
fim_f_mag = fftshift(np.abs(fftn(filtered_wimage)))

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(11, 12))
_panels = (
    (ax[0, 0], image, 'gray', 'Original Image'),
    (ax[0, 1], np.log(im_f_mag), 'magma', 'FFT Magnitude (log)'),
    (ax[1, 0], filtered_image, 'gray', 'Filtered Image'),
    (ax[1, 1], np.log(fim_f_mag), 'magma', 'FFT Magnitude (log)'),
)
for _axis, _data, _cmap, _title in _panels:
    _axis.imshow(_data, cmap=_cmap)
    _axis.set_title(_title)
    _axis.axis('off')
plt.show()
# Difference-of-Gaussians band-pass on the camera image (low sigma 1.5).
# To filter the noisy raccoon image instead, use: image = noisy
image = camera()
wimage = image * window('hann', image.shape)  # window image to improve FFT (tapers the borders)
filtered_image = difference_of_gaussians(image, 1.5)
filtered_wimage = filtered_image * window('hann', image.shape)
# Centred magnitude spectra of the windowed original and filtered images
im_f_mag = fftshift(np.abs(fftn(wimage)))
fim_f_mag = fftshift(np.abs(fftn(filtered_wimage)))

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(11, 12))
_panels = (
    (ax[0, 0], image, 'gray', 'Original Image'),
    (ax[0, 1], np.log(im_f_mag), 'magma', 'Original FFT Magnitude (log)'),
    (ax[1, 0], filtered_image, 'gray', 'Filtered Image'),
    (ax[1, 1], np.log(fim_f_mag), 'magma', 'Filtered FFT Magnitude (log)'),
)
for _axis, _data, _cmap, _title in _panels:
    _axis.imshow(_data, cmap=_cmap)
    _axis.set_title(_title)
    _axis.axis('off')
plt.show()
#Coherent Image Simulation
# Object: checkerboard test pattern, treated as an intensity image of side length L.
img = checkerboard()
# NOTE(review): flipud presumably reverses rows for a bottom-up (MATLAB "axis xy")
# convention; confirm this is wanted with imshow's default origin='upper'.
img = np.flipud(img)
M, N = img.shape
L = 0.3e-3    # image length in m
du = L/M      # sample interval (m)
# Centred sample coordinates built from integer indices, so there are exactly N (M)
# samples; np.arange with a float step can return an extra element due to round-off.
u = (np.arange(N) - N/2) * du
v = (np.arange(M) - M/2) * du
fimage = img_as_float(img)
imgfield = np.sqrt(fimage)  # ideal image field: amplitude = sqrt(intensity)

plt.figure()
plt.title('Original image')
plt.imshow(imgfield, extent=[-L/2, L/2, -L/2, L/2], cmap='gray')
plt.xlabel('u (m)')
plt.ylabel('v (m)')
plt.show()
# Optical system parameters
lamda = 0.5e-6         # wavelength in m
wxp = 6.25e-3          # exit pupil radius in m
zxp = 125e-3           # exit pupil distance in m
f0 = wxp/(lamda*zxp)   # cutoff frequency (cycles/m)

# Frequency coordinates, centred with zero frequency at index n//2 (the fftshift
# convention) and spacing 1/(n*du). They are built from integer indices so they have
# exactly N and M samples, matching the image; np.arange with a float step can return an
# extra element and break the later multiplication with the image spectrum.
fu = (np.arange(N) - N//2) / (N*du)
fv = (np.arange(M) - M//2) / (M*du)
Fu, Fv = np.meshgrid(fu, fv)
# Circular pupil: 1 for frequencies below the cutoff f0, 0 elsewhere (float for plotting)
H = (np.sqrt(Fu**2 + Fv**2)/f0 < 1).astype(float)

fig = plt.figure(figsize=(9, 7))
fig.suptitle('Coherent Transfer Function')
ax = fig.add_subplot(projection='3d')
# Plot the surface
ax.plot_surface(Fu, Fv, H, cmap=cm.gray)
ax.get_xaxis().set_ticks([])
ax.set_xlabel('fu (cycle/m)')
ax.get_yaxis().set_ticks([])
ax.set_ylabel('fv (cycle/m)')
ax.get_zaxis().set_ticks([])
fig.tight_layout()  # after the labels are set, so they are taken into account
plt.show()
# Filter the ideal field with the coherent transfer function.
# H and imgfield are centred (origin at index n//2), while fftn/ifftn expect the origin at
# index 0. ifftshift moves centred -> standard and fftshift moves standard -> centred.
# For even sizes the two shifts are identical; this pairing is also correct for odd sizes.
# NOTE: H is overwritten with its shifted version, so re-running this cell alone shifts
# it again.
H = ifftshift(H)
Gg = fftn(ifftshift(imgfield))
Gi = Gg * H
ui = fftshift(ifftn(Gi))
Ii = np.abs(ui)**2  # image intensity = |field|^2

#Simulated diffraction-limited coherent image
plt.figure()
plt.title('Simulated image')
plt.imshow(Ii, extent=[-L/2, L/2, -L/2, L/2], cmap='gray')
plt.xlabel('u (m)')
plt.ylabel('v (m)')
plt.show()